import numpy as np
import pylab as pl
import matplotlib_defaults
from scipy.optimize import leastsq

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# This script fits an exponential model to, and plots, the individual
# conductance traces recorded during the convergence experiments.
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #


# # # # # # # # # # # #
# # # P A R A M S # # #
# # # # # # # # # # # #

# Data directory and default data file (used when the prompt is left empty).
path = "data/"
fname = "TR15.npz"

# User-defined input; pressing Enter keeps the default file name above.
try:
    _read_line = raw_input  # Python 2
except NameError:
    _read_line = input  # Python 3
fname = _read_line("Enter filename: ").strip() or fname

# Trial indices to plot (min and max, inclusive)
Smin = 1
Smax = 40

# Arrangement on the figure (num rows, num cols)
plotgrid = (4,5)

# # # # # # # # # # # # # # # # # # #
# # # E N D   O F   P A R A M S # # #
# # # # # # # # # # # # # # # # # # #

# adjust some font sizes
pl.rc('figure', dpi=80)
pl.rc("font", size=7)


# Load data and add every stored array to the module namespace; the
# functions below read write_resist, write_pulse, num_write, param_name,
# param_dtype and param as globals.
# allow_pickle=True is needed because 'param' is stored as an object array
# (hence the .tolist() that turns it back into a dict). Unpickling can run
# arbitrary code, so only load data files you trust.
with np.load(path + fname, allow_pickle=True) as X:
    for k, v in X.items():
        globals()[k] = v
param = param.tolist()

# adjust Smax if necessary (assumes trials are numbered 1..N -- confirm)
Smax = min(Smax, len(param['S']))


# plot function
def draw_subplot(s,ax):
    """Plot the conductance trace of trial ``s`` into axes ``ax``.

    Reads the module-level data loaded from the .npz file (``param``,
    ``param_name``, ``param_dtype``, ``num_write``, ``write_resist``,
    ``write_pulse``) and the fit results ``fitrun`` / ``hrun``.

    The title lists the trial's parameter values. The blue line is the
    conductance 1/R, red dots along the bottom mark write pulses (entries
    of ``write_pulse`` equal to 1), and red dots trace the fitted model.

    Raises IndexError if no trial has ``param['S'] == s``.
    """
    # Row of the data (and fit) arrays that belongs to trial s.
    idx = (param['S'] == s).nonzero()[0][0]

    def fmt(dt, v):
        return ("%d" % v) if dt == "int" else ("%.2f" % v)

    title = ", ".join(fmt(dt, param[k][idx])
                      for k, dt in zip(param_name, param_dtype))
    ax.set_title("(" + title + ")", fontsize=7)

    x = np.arange(num_write)

    # conductances
    y = 1. / write_resist[idx]
    ax.plot(x, y, 'b', lw=0.5)
    ymin, ymax = y.min(), y.max()
    # Pad the y-range by 10%; a flat trace would otherwise give bottom == top.
    dy = ymax - ymin
    pad = 0.1 * dy if dy > 0 else 0.1 * (abs(ymax) or 1.0)
    ax.set_ylim(ymin - pad, ymax + pad)

    # pulses, drawn along the minimum of the trace
    px = (write_pulse[idx] == 1).nonzero()[0]
    ax.plot(px, [ymin] * len(px), 'r.', ms=1)
    ax.yaxis.get_major_formatter().set_powerlimits((0, 1))
    ax.set_xlim(0, num_write)

    # Fitted trace: fitrun rows follow the data rows, so index by idx
    # (not s-1, which only matches when trials are stored in order 1..N).
    p = fitrun[idx]
    ax.plot(x, hrun(p, x), 'r.', ms=1)
    print("Trial %d fit parameters: %s" % (s, p))


def fit_traces(resist=None, guess=None):
    """Fit an exponential model to the conductance of every run.

    Each run's conductance G = 1/R is fitted against the event index
    x = 0, 1, ..., len(run) - 1 with

        G(x) = a * exp(-b * (x - c)) + d

    using ``scipy.optimize.leastsq``. For b > 0 the offset ``d`` is the
    value the trace tends to as x grows; it is printed for each run.

    Parameters
    ----------
    resist : sequence of 1-D array-likes, optional
        Resistance traces, one per run. Defaults to the module-level
        ``write_resist`` loaded from the data file.
    guess : array-like of 4 floats, optional
        Initial [a, b, c, d]. Defaults to [0.0, 1.0, 0.0, 5e-6].

    Returns
    -------
    fitrun : ndarray, shape (n_runs, 4)
        Fitted [a, b, c, d] per run, in the order of ``resist``.
    hrun : callable
        The model function ``hrun(p, x)``.
    """
    if resist is None:
        resist = write_resist
    if guess is None:
        guess = np.array([0.0, 1.0, 0.0, 0.000005])
    else:
        guess = np.asarray(guess, dtype=float)

    def hrun(p, x):
        """Model function: p[0] * exp(-p[1] * (x - p[2])) + p[3]."""
        return p[0] * np.exp(-p[1] * (x - p[2])) + p[3]
    # Alternative log model tried previously (initial guess [0.0005, 1.0, 0.000005, 0]):
    #   p[0] * np.log(x - p[1]) + p[2]

    def frun(p, x, y):
        """Residual: measured conductance minus model."""
        return y - hrun(p, x)

    fitrun = np.zeros((len(resist), 4))
    for i, run in enumerate(resist):
        # Force float so 1/R is never integer division (Python 2).
        data = 1.0 / np.asarray(run, dtype=float)
        xdat = np.arange(len(data))  # event indices
        fitrun[i, :] = leastsq(frun, guess, args=(xdat, data))[0]
        # Parameter no. 3 (constant offset d) is the extrapolated final value.
        print("Run %d extrapolated convergence to %.6g" % (i + 1, fitrun[i, 3]))

    return fitrun, hrun


fitrun, hrun = fit_traces()

figs = []
subplots = []

pl.interactive(False)
ppf = plotgrid[0] * plotgrid[1]  # plots per figure
ntrials = Smax - Smin + 1
for sn, s in enumerate(range(Smin, Smax + 1)):
    if sn % ppf == 0:
        # Start a new figure with as many subplots as trials remain (up to ppf).
        fig = pl.figure(figsize=(15, 8))
        figs.append(fig)
        sp = []
        for i in range(1, min(ppf, ntrials - sn) + 1):
            subp = fig.add_subplot(plotgrid[0], plotgrid[1], i)
            subp.set_xticks([0, num_write])
            sp.append(subp)
        subplots.append(sp)
        fig.subplots_adjust(left=0.03, bottom=0.05, right=0.97, top=0.95,
                            wspace=0.3, hspace=0.35)

    # Floor division picks the figure; the remainder picks the subplot in it.
    subp = subplots[sn // ppf][sn % ppf]
    draw_subplot(s, subp)


# Extra figure for bringing up a single desired trial (23, "run 2.1").
highlight = 23
if np.any(param['S'] == highlight):
    fig2 = pl.figure()
    subp = fig2.add_subplot(1, 1, 1)
    draw_subplot(highlight, subp)
else:
    print("Trial %d not found; skipping the single-trial figure." % highlight)


pl.interactive(True)
pl.show()

# Keep the window open; raw_input on Python 2 (input() there would eval the line).
try:
    _read_line = raw_input  # Python 2
except NameError:
    _read_line = input  # Python 3
_read_line("Press enter to terminate.")